{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "gpuClass": "standard",
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# keras-hub installation"
   ],
   "metadata": {
    "id": "FKTVkreu3MGG"
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NtcTiJqG20LV",
    "outputId": "d03cd2ad-50ea-4553-8c91-255bcfed9065"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n"
     ]
    }
   ],
   "source": [
    "!pip install -q git+https://github.com/keras-team/keras-hub.git tensorflow --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Imports"
   ],
   "metadata": {
    "id": "ofxRv3CK3Pdp"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import keras_hub\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import tensorflow_datasets as tfds"
   ],
   "metadata": {
    "id": "AHJ-nbcT3ATU"
   },
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "#Data"
   ],
   "metadata": {
    "id": "r86tfpfB3T6_"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Load Dataset\n",
    "\n",
    "train_ds, valid_ds = tfds.load(\n",
    "    \"glue/sst2\",\n",
    "    split=[\"train\", \"validation\"],\n",
    "    batch_size=16,\n",
    ")\n",
    "\n",
    "\n",
    "def split_features(x):\n",
    "    # GLUE comes with dictonary data, we convert it to a uniform format\n",
    "    # (features, label), where features is a tuple consisting of all\n",
    "    # features.\n",
    "    features = x[\"sentence\"]\n",
    "    label = x[\"label\"]\n",
    "    return (features, label)\n",
    "\n",
    "\n",
    "train_ds = train_ds.map(\n",
    "    split_features, num_parallel_calls=tf.data.AUTOTUNE\n",
    ").prefetch(tf.data.AUTOTUNE)\n",
    "valid_ds = valid_ds.map(\n",
    "    split_features, num_parallel_calls=tf.data.AUTOTUNE\n",
    ").prefetch(tf.data.AUTOTUNE)"
   ],
   "metadata": {
    "id": "8UDVK92k3VKQ"
   },
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Look first training set batch\n",
    "# The format is (string_tensor, label_tensor)\n",
    "train_ds.take(1).get_single_element()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "w8cfKvOl3aiT",
    "outputId": "25c58ee3-88b2-4c07-9973-6bc0b966c875"
   },
   "execution_count": 4,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(<tf.Tensor: shape=(16,), dtype=string, numpy=\n",
       " array([b'for the uninitiated plays better on video with the sound ',\n",
       "        b'like a giant commercial for universal studios , where much of the action takes place ',\n",
       "        b'company once again dazzle and delight us ',\n",
       "        b\"'s no surprise that as a director washington demands and receives excellent performances , from himself and from newcomer derek luke \",\n",
       "        b', this cross-cultural soap opera is painfully formulaic and stilted . ',\n",
       "        b\", the film is n't nearly as downbeat as it sounds , but strikes a tone that 's alternately melancholic , hopeful and strangely funny . \",\n",
       "        b'only masochistic moviegoers need apply . ',\n",
       "        b'convince almost everyone that it was put on the screen , just for them ',\n",
       "        b\"like the english patient and the unbearable lightness of being , the hours is one of those reputedly `` unfilmable '' novels that has bucked the odds to emerge as an exquisite motion picture in its own right . \",\n",
       "        b'his supple understanding of the role ',\n",
       "        b'revelatory nor truly edgy -- merely crassly flamboyant and comedically labored . ',\n",
       "        b'teaming ', b\"'s not very well shot or composed or edited \",\n",
       "        b'of the startling intimacy ', b'draggy ',\n",
       "        b'admire this film for its harsh objectivity and refusal to seek our tears , our sympathies . '],\n",
       "       dtype=object)>,\n",
       " <tf.Tensor: shape=(16,), dtype=int64, numpy=array([0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1])>)"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "#Model"
   ],
   "metadata": {
    "id": "CZTYpu0y4Y6A"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Create the TextClassifier Model\n",
    "# For more details please look https://keras.io/guides/keras_hub/getting_started/\n",
    "\n",
    "classifier = keras_hub.models.BertTextClassifier.from_preset(\n",
    "    \"bert_tiny_en_uncased\", num_classes=2, dropout=0.1\n",
    ")\n",
    "\n",
    "# Add loss function, optimizer and metrics\n",
    "classifier.compile(\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=keras.optimizers.experimental.AdamW(5e-5),\n",
    "    metrics=keras.metrics.SparseCategoricalAccuracy(),\n",
    "    jit_compile=True,\n",
    ")\n",
    "\n",
    "# To see the summary(layers) of the model we need to build it(or call it once) but we don't need\n",
    "# to do it if we just want to train it on a downstream task.\n",
    "classifier.layers[-1].build(\n",
    "    [\n",
    "        None,\n",
    "    ]\n",
    ")\n",
    "classifier.summary()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1KxlC7gY32MH",
    "outputId": "b02c8c90-9f37-40b8-843b-5fcc66819e80"
   },
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Model: \"bert_classifier\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " padding_mask (InputLayer)      [(None, None)]       0           []                               \n",
      "                                                                                                  \n",
      " segment_ids (InputLayer)       [(None, None)]       0           []                               \n",
      "                                                                                                  \n",
      " token_ids (InputLayer)         [(None, None)]       0           []                               \n",
      "                                                                                                  \n",
      " bert_backbone (BertBackbone)   {'sequence_output':  4385920     ['padding_mask[0][0]',           \n",
      "                                 (None, None, 128),               'segment_ids[0][0]',            \n",
      "                                 'pooled_output': (               'token_ids[0][0]']              \n",
      "                                None, 128)}                                                       \n",
      "                                                                                                  \n",
      " dropout (Dropout)              (None, 128)          0           ['bert_backbone[0][0]']          \n",
      "                                                                                                  \n",
      " logits (Dense)                 (None, 2)            258         ['dropout[0][0]']                \n",
      "                                                                                                  \n",
      " bert_preprocessor (BertPreproc  multiple            0           []                               \n",
      " essor)                                                                                           \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 4,386,178\n",
      "Trainable params: 4,386,178\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Train the model\n",
    "\n",
    "N_EPOCHS = 2\n",
    "classifier.fit(\n",
    "    train_ds,\n",
    "    validation_data=valid_ds,\n",
    "    epochs=N_EPOCHS,\n",
    ")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fxVfW7t14wRS",
    "outputId": "b9a4277d-4fa1-4668-9373-d8cef77bcc6a"
   },
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.\n",
      "Instructions for updating:\n",
      "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Epoch 1/2\n",
      "4210/4210 [==============================] - 828s 191ms/step - loss: 0.3782 - sparse_categorical_accuracy: 0.8299 - val_loss: 0.4344 - val_sparse_categorical_accuracy: 0.8165\n",
      "Epoch 2/2\n",
      "4210/4210 [==============================] - 783s 186ms/step - loss: 0.2409 - sparse_categorical_accuracy: 0.9039 - val_loss: 0.4626 - val_sparse_categorical_accuracy: 0.8222\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f2390120df0>"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Save the weights\n",
    "\n",
    "model_name = f\"bert_tiny_uncased_en_sst2-epochs={N_EPOCHS}.h5\"\n",
    "classifier.save_weights(f\"/content/{model_name}\")"
   ],
   "metadata": {
    "id": "QExT3HYu6g6h"
   },
   "execution_count": 7,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Load the weigths\n",
    "\n",
    "weights_path = f\"bert_tiny_uncased_en_sst2-epochs={N_EPOCHS}.h5\"\n",
    "classifier.load_weights(weights_path)"
   ],
   "metadata": {
    "id": "8Y0ZTKRj62nq"
   },
   "execution_count": 8,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "#Test the model"
   ],
   "metadata": {
    "id": "XWUIUhw47drP"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# We will shuffle the valid dataset and take 1st example\n",
    "\n",
    "shuffled_valid_ds = valid_ds.shuffle(55).rebatch(1)\n",
    "element = shuffled_valid_ds.take(1).get_single_element()\n",
    "\n",
    "pred_logits = classifier.predict(element[0])\n",
    "pred = tf.argmax(pred_logits)\n",
    "\n",
    "print(\n",
    "    f\"Text :: {element[0].numpy()[0].decode('utf-8')} \\nLabel :: {element[1].numpy()[0]} \"\n",
    "    f\"\\nModel Prediction :: {tf.argmax(pred_logits, axis=1).numpy()[0]}\"\n",
    ")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FoQk4cy27jSk",
    "outputId": "aff204c9-4879-403b-91ce-afcc3679b103"
   },
   "execution_count": 9,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "1/1 [==============================] - 3s 3s/step\n",
      "Text :: preaches to two completely different choirs at the same time , which is a pretty amazing accomplishment .  \n",
      "Label :: 1 \n",
      "Model Prediction :: 1\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# You can test the model with your own statement!\n",
    "\n",
    "label_dict = {0: \"Bad Statement\", 1: \"Good Statement\"}  # As for this dataset\n",
    "\n",
    "output = tf.argmax(\n",
    "    classifier.predict([\"Lord of the rings is best\"]), axis=1\n",
    ").numpy()[0]\n",
    "print(label_dict[output])"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7MB2U6Bh_j_m",
    "outputId": "37a8223a-4178-4176-e660-258ed9b856c7"
   },
   "execution_count": 10,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "1/1 [==============================] - 1s 902ms/step\n",
      "Good Statement\n"
     ]
    }
   ]
  }
 ]
}
